#include <slang/syntax/SyntaxTree.h>
#include <slang/syntax/SyntaxPrinter.h>
#include <slang/syntax/SyntaxVisitor.h>
#include <slang/syntax/SyntaxNode.h>

#include <iostream>
#include <cassert>

/// Debug visitor that dumps a syntax tree to stdout, one node per line.
///
/// Each printed line is indented three spaces per nesting level and shows the
/// node's SyntaxKind, its parent's kind (when the node has a parent), and, for
/// FunctionPort / Declarator nodes, the port direction / declared name.
/// List container nodes (SyntaxList, TokenList, SeparatedList) are not printed
/// themselves, but their children are still visited.
class NodePrintingVisitor {
  using SyntaxNode = slang::SyntaxNode;
  using Token = slang::Token;

  // Current nesting level; raised by one while a node's children are visited.
  size_t depth = 0;

public:
  /// Entry point invoked by SyntaxNode::visit(); every concrete node type is
  /// routed to visitDefault().
  template<typename T>
  void visit(const T& t) {
    visitDefault(t);
  }

  /// Prints @p node (unless it is a list container), then recurses into each
  /// non-null child node one nesting level deeper.
  void visitDefault(const SyntaxNode& node) {
    const bool isListNode = node.kind == slang::SyntaxKind::SyntaxList ||
                            node.kind == slang::SyntaxKind::TokenList ||
                            node.kind == slang::SyntaxKind::SeparatedList;
    if (!isListNode)
      printLine(node);

    ++depth;
    for (size_t idx = 0; idx < node.getChildCount(); ++idx) {
      if (auto child = node.childNode(idx))
        child->visit(*this);
    }
    --depth;
  }

  /// Invalid nodes are ignored.
  void visitInvalid(const SyntaxNode&) {}

private:
  /// Writes one indented description line for @p node to stdout.
  void printLine(const SyntaxNode& node) {
    std::cout << std::string(depth * 3, ' ') << node.kind;
    if (node.parent)
      std::cout << "(parent: " << node.parent->kind << ")";

    switch (node.kind) {
      case slang::SyntaxKind::FunctionPort:
        std::cout << "(direction: "
                  << node.as<slang::FunctionPortSyntax>().direction.rawText() << ")";
        break;
      case slang::SyntaxKind::Declarator:
        std::cout << "(name: "
                  << node.as<slang::DeclaratorSyntax>().name.rawText() << ")";
        break;
      default:
        break;
    }
    std::cout << std::endl;
  }

  // No-op token hook; this class never calls it and tokens are not printed.
  void visitToken(Token) {}
};

/// Prints a short summary of every class method (task or function) in a tree.
///
/// For each ClassMethodDeclaration it prints whether the method is a task or a
/// function and, when its name is a plain identifier, that name; for functions
/// it additionally prints the return type. Because handle() does not recurse,
/// nodes below a class method are not visited.
class ScanMethodVisitor: public slang::SyntaxVisitor<ScanMethodVisitor>
{
  using ClassMethodDeclaration = slang::ClassMethodDeclarationSyntax;
public:
  void handle(const ClassMethodDeclaration &decl)
  {
    const slang::SyntaxKind kind = decl.declaration->kind;
    if (kind != slang::SyntaxKind::TaskDeclaration &&
        kind != slang::SyntaxKind::FunctionDeclaration) {
      // Classify before use: previously `isTask` stayed uninitialized here when
      // NDEBUG compiled the assert away, and reading it was undefined behaviour.
      assert(false && "class method is neither a task nor a function");
      return;
    }
    const bool isTask = (kind == slang::SyntaxKind::TaskDeclaration);

    const slang::NameSyntax *name = decl.declaration->prototype->name;
    if (slang::IdentifierNameSyntax::isKind(name->kind)) {
      const auto &id = name->as<slang::IdentifierNameSyntax>();
      std::cout << "Class method: " << (isTask ? "task " : "function ") << id.toString() << std::endl;
    } else {
      // Non-identifier names (e.g. scoped names) are not decoded here.
      std::cout << "[Unknown class decl]" << std::endl;
    }

    if (!isTask) {
      auto returnType = decl.declaration->prototype->returnType;
      std::cout << "Return type: " << returnType->toString() << std::endl;
    }
  }
};

/// Rewrites class-method functions into tasks.
///
/// For a non-void class-method function
///     function T f(args); ... return e; ... endfunction
/// the rewrite produces (modulo whitespace)
///     task f(output T retVal, args); ... retVal = e; return; ... endtask
/// Void functions only get their keywords swapped and return type removed;
/// class-method tasks are left untouched. `return <expr>;` statements are only
/// rewritten inside a function that is actually being converted, so functions
/// outside classes are not affected.
///
/// NOTE(review): the function/endfunction -> task/endtask swap is done by
/// assigning new token kinds on the ORIGINAL tree's tokens, so the input tree
/// is mutated in place. This presumably relies on keyword text being derived
/// from the token kind when printing; confirm against the slang version in use.
class FunctionToTaskVisitor: public slang::SyntaxRewriter<FunctionToTaskVisitor>
{
public:
  void handle(const slang::ClassMethodDeclarationSyntax &decl)
  {
    const slang::SyntaxKind kind = decl.declaration->kind;
    if (kind == slang::SyntaxKind::TaskDeclaration) {
      // Already a task; nothing to convert and no valued returns to rewrite.
      return;
    }
    if (kind != slang::SyntaxKind::FunctionDeclaration) {
      // Previously `isTask` was read uninitialized here under NDEBUG (UB).
      assert(false && "class method is neither a task nor a function");
      return;
    }

    auto prototype = decl.declaration->prototype;
    const bool isVoid = prototype->returnType->kind == slang::SyntaxKind::VoidType;

    if (!isVoid) {
      // The return value becomes a leading `output` port, inserted before the
      // first existing port. Without a port list (or with an empty one) the old
      // code indexed ports[0] out of bounds / through null; skip instead.
      auto portList = prototype->portList;
      if (!portList || portList->ports.empty()) {
        std::cerr << "warning: not converting function" << prototype->name->toString()
                  << " to a task: it has no ports to insert 'retVal' before" << std::endl;
        return;
      }
      insertBefore(*portList->ports[0],
                   parse("output " + prototype->returnType->toString() + " retVal,"));
    }

    prototype->keyword.kind = slang::TokenKind::TaskKeyword;
    replace(*prototype->returnType, parse(""));
    decl.declaration->end.kind = slang::TokenKind::EndTaskKeyword;

    // Continue into the body so valued return statements get rewritten, but only
    // when there is a retVal port for them to assign.
    const bool wasInside = insideConvertedFunction;
    insideConvertedFunction = !isVoid;
    visitDefault(decl);
    insideConvertedFunction = wasInside;
  }

  /// Turns `return <expr>;` into `retVal = <expr>; return;` inside a function
  /// currently being converted; all other return statements are left alone.
  void handle(const slang::ReturnStatementSyntax &retn)
  {
    if (!insideConvertedFunction || retn.returnValue == nullptr)
      return;
    replace(retn, parse("retVal = " + retn.returnValue->toString() + ";\nreturn;"));
  }

private:
  // True only while visiting the body of a non-void function being converted.
  bool insideConvertedFunction = false;
};

/// Demo driver. Parses the SystemVerilog file named on the command line, then
/// prints the node dump, the class-method scan, a reprint of the file, and
/// finally the file after the function-to-task rewrite.
/// Returns 1 on bad usage or when the file cannot be loaded, 0 otherwise.
int main(int argc, char *argv[])
{
  if (argc != 2) {
    std::cerr << "usage: " << argv[0] << " file" << std::endl;
    return 1;
  }

  auto tree = slang::SyntaxTree::fromFile(argv[1]);
  if (!tree) {
    // NOTE(review): assumes this slang version's fromFile() returns a null tree
    // when the file cannot be read; without this check tree->root() would
    // dereference null.
    std::cerr << "error: could not load '" << argv[1] << "'" << std::endl;
    return 1;
  }

  std::cout << "============= print node =================" << std::endl;
  NodePrintingVisitor visitor;
  tree->root().visit(visitor);

  std::cout << "============= scan visitor ===============" << std::endl;
  ScanMethodVisitor scan_method;
  tree->root().visit(scan_method);

  std::cout << "============= print file ================" << std::endl;
  std::string s = slang::SyntaxPrinter::printFile(*tree);
  std::cout << s << std::endl;

  std::cout << "============= rewrite ===================" << std::endl;
  tree = FunctionToTaskVisitor().transform(tree);
  std::cout << slang::SyntaxPrinter::printFile(*tree) << std::endl;
  return 0;
}
